#!/usr/bin/python
# Trivial toy implementation of forward-backward
# Follows Jason Eisner's SpreadSheet-based teaching tool, 
#  and produces the same results.
import simplejson
import pprint

class forwardBackward(object):
    """Toy forward-backward re-estimation for a discrete hidden Markov model.

    The model is held in four plain dictionaries loaded from JSON:

    * ``init_mat[state]``            -- probability of starting in ``state``
    * ``state2obs[state][symbol]``   -- probability of emitting ``symbol``
    * ``state2state[prev][curr]``    -- probability of moving ``prev -> curr``
    * ``final_mat[state]``           -- probability of stopping in ``state``

    ``alphas`` and ``betas`` are lists with one dictionary per observation;
    each dictionary maps a state to a two-item list ``[probability, path]``
    where ``path`` is a space-separated string of state labels following the
    single largest contribution at every step.

    The ``re_*`` attributes hold the matrices produced by the most recent
    call to :meth:`_reestimate_probs`.
    """

    def __init__(self, obs, init_file, state2obs_file, state2state_file, final_file):
        """Load the probability matrices and tokenize the observations.

        Args:
            obs: Whitespace-separated observation symbols.
            init_file: Path to the JSON initial-probability matrix.
            state2obs_file: Path to the JSON state-to-observation matrix.
            state2state_file: Path to the JSON state-to-state matrix.
            final_file: Path to the JSON final-probability matrix.

        Raises:
            ValueError: If ``obs`` contains no observation symbols.
        """
        self.alphas = []
        self.betas = []
        self.init_mat = self._load_json_from_file(init_file)
        self.state2obs = self._load_json_from_file(state2obs_file)
        self.state2state = self._load_json_from_file(state2state_file)
        self.final_mat = self._load_json_from_file(final_file)
        # split() without an argument ignores repeated/leading/trailing
        # whitespace, so no empty "symbols" end up in the sequence.
        self.obs = obs.split()
        if not self.obs:
            raise ValueError("Observation sequence is empty.")
        # Receptacles for the most recently re-estimated values.
        self.re_init_mat = {}
        self.re_state2obs = {}
        self.re_state2state = {}
        self.re_final_mat = {}

    def _load_json_from_file(self, infile):
        """Return the JSON object stored in the file at path ``infile``."""
        # The context manager guarantees the handle is closed again.
        with open(infile, "r") as handle:
            return simplejson.loads(handle.read())

    def _init_alphas(self):
        """Allocate the alpha lattice and fill in its first column.

        The first column holds, for every state in ``init_mat``, the initial
        probability times the probability of emitting the first observation.
        """
        first_symbol = self.obs[0]
        self.alphas = [{} for _ in range(len(self.obs))]
        for state in self.init_mat:
            prob = self.init_mat[state] * self.state2obs[state][first_symbol]
            self.alphas[0][state] = [prob, "START %s" % (state)]

    def _init_betas(self):
        """Allocate the beta lattice and fill in its last column.

        The last column holds the final probability of every state in
        ``final_mat``.
        """
        self.betas = [{} for _ in range(len(self.obs))]
        last_column = self.betas[-1]
        for state in self.final_mat:
            last_column[state] = [self.final_mat[state], "END %s %s" % (state, state)]

    def _calc_forward_alphas(self):
        """Fill in every alpha column after the first.

        For each stage N and each state Y at that stage, every state X of
        stage N-1 contributes::

            prob(observation N | Y) * prob(Y | X) * alpha(N-1, X)

        The stored probability is the SUM of those contributions; the stored
        path extends the path of the single largest contribution.

        Raises:
            SystemExit: With status 1, after printing an error, when a state
                cannot be reached at all (its summed probability is zero).
        """
        for i in range(1, len(self.obs)):
            symbol = self.obs[i]
            for curr in self.state2obs:
                total = 0
                best_share = 0
                best_path = None
                for prev in self.state2state:
                    val, path = self.alphas[i - 1][prev]
                    share = self.state2obs[curr][symbol] \
                        * self.state2state[prev][curr] \
                        * val
                    total += share
                    if share > best_share:
                        best_share = share
                        best_path = path
                if total == 0.0:
                    print("ERROR: curr_prob==0.0.  Failed to reach final state for the current iteration.")
                    # SystemExit needs no import, so this also works when
                    # the class is used from another module.
                    raise SystemExit(1)
                if total > 0:
                    self.alphas[i][curr] = [total, "%s %s" % (best_path, curr)]
                else:
                    self.alphas[i][curr] = [0, "%s %s" % (0, curr)]

    def _calc_backward_betas(self):
        """Fill in every beta column before the last, moving right to left.

        Mirrors :meth:`_calc_forward_alphas`: for each state Y at stage N,
        every state X of stage N+1 contributes::

            prob(observation N+1 | X) * prob(X | Y) * beta(N+1, X)

        Unlike the forward pass, a zero total is not an error; the entry is
        then stored with probability 0 and the placeholder path label ``0``.
        """
        for i in range(len(self.obs) - 2, -1, -1):
            next_symbol = self.obs[i + 1]
            for curr in self.state2obs:
                total = 0
                best_share = 0
                best_path = None
                for nxt in self.state2state:
                    val, path = self.betas[i + 1][nxt]
                    share = self.state2obs[nxt][next_symbol] \
                        * self.state2state[curr][nxt] \
                        * val
                    total += share
                    if share > best_share:
                        best_share = share
                        best_path = path
                if total > 0:
                    self.betas[i][curr] = [total, "%s %s" % (best_path, curr)]
                else:
                    self.betas[i][curr] = [0, "%s %s" % (0, curr)]

    def best_path(self):
        """Print and score the best-scoring entry of the last alpha column.

        Prints the path string and probability of the state with the largest
        alpha value at the final stage, separated by a tab.

        Returns:
            The largest alpha probability in the last column (0.0 when no
            entry is greater than zero).
        """
        top_prob = 0.0
        top_path = None
        for val, path in self.alphas[-1].values():
            if val > top_prob:
                top_prob = val
                top_path = path
        print("%s\t%s" % (top_path, str(top_prob)))
        return top_prob

    def _reestimate_probs(self):
        """Re-estimate all four matrices from the current alphas and betas.

        Steps:

        1. Multiply alpha and beta per state and stage, and normalize each
           stage so it sums to one.
        2. Accumulate those values into per-state emission counts and
           per-state totals, and derive the new emission matrix.
        3. Derive the new initial, transition and final matrices.

        The new matrices are built in fresh dictionaries so that the current
        model is never modified while it is still being read. They replace
        the current model at the end, and ``alphas``/``betas`` are reset.

        Raises:
            ZeroDivisionError: If the alpha-beta products of a stage sum to
                zero.
        """
        new_init = {}
        new_state2obs = {}
        new_state2state = {}
        new_final = {}

        emission_counts = {}
        state_totals = {}
        # Sum of the alpha-beta products of the most recent stage; after the
        # loop it holds the value for the last stage.
        stage_sum = 0.0

        for j, alpha_col in enumerate(self.alphas):
            symbol = self.obs[j]
            products = {}
            stage_sum = 0.0
            for state in alpha_col:
                product = alpha_col[state][0] * self.betas[j][state][0]
                products[state] = product
                stage_sum += product
            for state in products:
                share = products[state] / stage_sum
                row = emission_counts.setdefault(state, {})
                row[symbol] = row.get(symbol, 0.0) + share
                state_totals[state] = state_totals.get(state, 0.0) + share

        # Re-estimated state-to-observation matrix.
        for state in emission_counts:
            new_row = new_state2obs.setdefault(state, {})
            for symbol in emission_counts[state]:
                new_row[symbol] = emission_counts[state][symbol] / state_totals[state]

        # Re-estimated initial matrix.
        for state in self.alphas[0]:
            product = self.alphas[0][state][0] * self.betas[0][state][0]
            new_init[state] = product / stage_sum

        # Re-estimated state-to-state matrix.
        transitions = {}
        for j in range(1, len(self.obs)):
            symbol = self.obs[j]
            for s1 in self.state2state:
                for s2 in self.state2state[s1]:
                    al = self.alphas[j - 1][s1][0]
                    be = self.betas[j][s2][0]
                    trans = self.state2state[s1][s2]
                    conf = self.state2obs[s2][symbol]
                    share = al * be * trans * conf / stage_sum
                    row = transitions.setdefault(s1, {})
                    row[s2] = row.get(s2, 0.0) + share

        for s1 in transitions:
            new_row = new_state2state.setdefault(s1, {})
            for s2 in transitions[s1]:
                new_row[s2] = transitions[s1][s2] / state_totals[s1]

        # Re-estimated final matrix.
        for state in self.alphas[-1]:
            product = self.alphas[-1][state][0] * self.betas[-1][state][0]
            new_final[state] = product / stage_sum / state_totals[state]

        self.re_init_mat = new_init
        self.re_state2obs = new_state2obs
        self.re_state2state = new_state2state
        self.re_final_mat = new_final

        self.init_mat = new_init
        self.state2obs = new_state2obs
        self.state2state = new_state2state
        self.final_mat = new_final
        # Reset the alphas and betas for the next pass.
        self.alphas = []
        self.betas = []

    def iterate(self, n_iter=5, ratio=4e-20, verbose=False):
        """Run forward-backward passes until convergence or ``n_iter`` passes.

        Every pass rebuilds the alpha and beta lattices, scores them with
        :meth:`best_path`, and - unless the score moved by less than
        ``ratio`` since the previous pass - re-estimates the model.

        Args:
            n_iter: Maximum number of passes.
            ratio: Stop when the absolute score change is below this value.
            verbose: Also print the matrices and the score change per pass.
        """
        prev_likelihood = 9999999
        for i in range(n_iter):
            if verbose:
                print("Iteration: %d" % i)
                print("S2O")
                pprint.pprint(self.state2obs)
                print("S2S")
                pprint.pprint(self.state2state)
            self._init_alphas()
            self._init_betas()
            self._calc_forward_alphas()
            self._calc_backward_betas()

            likelihood = self.best_path()
            change = abs(likelihood - prev_likelihood)
            if change < ratio:
                print("Achieved convergence ratio: %s ; Stopping at iteration: %d." % (ratio, i))
                break
            elif verbose:
                print("Likelihood change: %s" % (change,))
            self._reestimate_probs()
            prev_likelihood = likelihood
        return

def build_arg_parser():
    """Build the command-line parser for the forward-backward script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--init``, ``--s2o``,
        ``--s2s``, ``--final`` and ``--obs`` (all required) plus ``--n_iter``,
        ``--ratio`` and ``--verbose``.
    """
    import argparse

    def parse_bool(text):
        """Turn a command-line word such as ``true`` or ``0`` into a bool."""
        # bool("False") is True, so ``type=bool`` cannot be used to parse
        # a flag value; map the usual spellings explicitly instead.
        word = text.strip().lower()
        if word in ("1", "true", "t", "yes", "y", "on"):
            return True
        if word in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (text,))

    parser = argparse.ArgumentParser()
    parser.add_argument('--init',    "-i", help='2D Initialization probability matrix. JSON format.', required=True)
    parser.add_argument('--s2o',     "-o", help='3D state-to-observation transition matrix. JSON format.', required=True)
    parser.add_argument('--s2s',     "-s", help='3d state-to-state transition matrix. JSON format.', required=True)
    parser.add_argument('--final',   "-f", help='2D Final probability matrix. JSON format.', required=True)
    # The observation sequence is split unconditionally by the model, so a
    # missing value must be reported by the parser rather than crash later.
    parser.add_argument('--obs',     "-b", help='Observation sequence.  Must correspond to observations values in state-to-observation matrix.', required=True)
    parser.add_argument('--n_iter',  "-n", type=int,   default=10,    help='Maximum number of iterations for forward-backward algorithm.')
    parser.add_argument('--ratio',   "-r", type=float, default=4e-20, help='Convergence ratio.')
    # nargs='?' keeps "--verbose True" working and additionally allows a
    # bare "--verbose"; const is the value used for the bare form.
    parser.add_argument('--verbose', "-v", type=parse_bool, nargs='?', const=True, default=False, help="Verbosity.")
    return parser


if __name__=="__main__":
    # Kept so that `sys` stays bound at module scope when run as a script,
    # exactly as before.
    import sys

    args = build_arg_parser().parse_args()

    fb = forwardBackward(args.obs, args.init, args.s2o, args.s2s, args.final)
    fb.iterate(n_iter=args.n_iter, ratio=args.ratio, verbose=args.verbose)
